{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.HydraPlus"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HydraPlus"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">Hydra: competing convolutional kernels for fast and accurate time series classification.\n",
    "\n",
    "This is a PyTorch implementation of Hydra, adapted by Ignacio Oguiza and based on:\n",
    "\n",
    "Dempster, A., Schmidt, D. F., & Webb, G. I. (2023). Hydra: Competing convolutional kernels for fast and accurate time series classification. Data Mining and Knowledge Discovery, 1-27.\n",
    "\n",
    "Original paper: https://link.springer.com/article/10.1007/s10618-023-00939-3\n",
    "\n",
    "Original repository: https://github.com/angus924/hydra"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from collections import OrderedDict\n",
    "from typing import Any\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from tsai.imports import default_device\n",
    "from tsai.models.layers import Flatten, rocket_nd_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class HydraBackbonePlus(nn.Module):\n",
    "    \"\"\"Hydra transform: groups of random, dilated convolutional kernels that compete per time step.\n",
    "\n",
    "    For each dilation and each group, the `k` kernels of the group are convolved with a random\n",
    "    sum of input channels. At every time step the kernel with the largest response adds that\n",
    "    response to its feature, and the kernel with the smallest response adds 1 to its feature.\n",
    "    `forward` returns a `(batch_size, num_features)` tensor.\n",
    "\n",
    "    The kernels (`W`) and channel indices (`I`) are random and fixed. They are registered as\n",
    "    non-persistent buffers, so they follow `.to()`/`.cuda()` but are not part of `state_dict()`:\n",
    "    a rebuilt module draws new kernels unless the RNG is seeded identically.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, c_in, c_out, seq_len, k = 8, g = 64, max_c_in = 8, clip=True, device=default_device(), zero_init=True,\n",
    "                 use_diff=True):\n",
    "        # c_in: number of input channels; seq_len: number of time steps (must be >= 2)\n",
    "        # k: kernels per group; g: number of groups; max_c_in: max channels summed per group\n",
    "        # clip: apply ReLU to the output features; device: where kernels / indices are created\n",
    "        # use_diff: if True (and g > 1), half of the groups are applied to diff(X)\n",
    "        # c_out and zero_init are not used by the backbone (kept for signature compatibility)\n",
    "\n",
    "        super().__init__()\n",
    "\n",
    "        if seq_len < 2:\n",
    "            raise ValueError(f\"seq_len must be >= 2, got {seq_len}\")\n",
    "        if k < 1 or g < 1:\n",
    "            raise ValueError(f\"k and g must be >= 1, got k={k}, g={g}\")\n",
    "\n",
    "        self.k = k # num kernels per group\n",
    "        self.g = g # num groups\n",
    "\n",
    "        kernel_length = 9\n",
    "        max_exponent = np.log2((seq_len - 1) / (kernel_length - 1))\n",
    "\n",
    "        # seq_len < 9 gives a negative exponent: clamp it so that dilation 1 is always available\n",
    "        num_dilations = max(int(max_exponent), 0) + 1\n",
    "        self.dilations = 2 ** torch.arange(num_dilations, device=device)\n",
    "        self.num_dilations = len(self.dilations)\n",
    "\n",
    "        self.paddings = torch.div((kernel_length - 1) * self.dilations, 2, rounding_mode = \"floor\").int()\n",
    "\n",
    "        # if g > 1 (and use_diff), assign: half the groups to X, half the groups to diff(X)\n",
    "        self.use_diff = bool(use_diff) and g > 1\n",
    "        divisor = 2 if self.use_diff else 1\n",
    "        self.divisor = divisor\n",
    "        _g = g // divisor\n",
    "        self._g = _g\n",
    "\n",
    "        # fixed random kernels: one (divisor, k * _g, 1, kernel_length) tensor per dilation\n",
    "        for i in range(self.num_dilations):\n",
    "            W = self.normalize(torch.randn(divisor, k * _g, 1, kernel_length)).to(device=device)\n",
    "            self.register_buffer(f\"W_{i}\", W, persistent=False)\n",
    "\n",
    "        # combine c_in // 2 channels (2 < n < max_c_in)\n",
    "        c_in_per = int(np.clip(c_in // 2, 2, max_c_in))\n",
    "        for i in range(self.num_dilations):\n",
    "            I = torch.randint(0, c_in, (divisor, _g, c_in_per), device=device)\n",
    "            self.register_buffer(f\"I_{i}\", I, persistent=False)\n",
    "\n",
    "        # clip values\n",
    "        self.clip = clip\n",
    "\n",
    "        self.device = device\n",
    "        self.num_features = divisor * self._g * self.k * self.num_dilations * 2\n",
    "\n",
    "    @property\n",
    "    def W(self):\n",
    "        # per-dilation kernels, exposed as a list for code that indexes W[dilation][diff]\n",
    "        return [getattr(self, f\"W_{i}\") for i in range(self.num_dilations)]\n",
    "\n",
    "    @property\n",
    "    def I(self):\n",
    "        # per-dilation channel indices, exposed as a list for code that indexes I[dilation][diff]\n",
    "        return [getattr(self, f\"I_{i}\") for i in range(self.num_dilations)]\n",
    "\n",
    "    @staticmethod\n",
    "    def normalize(W):\n",
    "        # in place: zero mean and unit L1 norm along the kernel (last) dimension\n",
    "        W -= W.mean(-1, keepdims = True)\n",
    "        W /= W.abs().sum(-1, keepdims = True)\n",
    "        return W\n",
    "\n",
    "    # transform in batches of *batch_size*\n",
    "    def batch(self, X, split=None, batch_size=256):\n",
    "        # split: optional indices of the rows of X to transform (all rows when None).\n",
    "        # The split is honoured even when X has fewer than batch_size rows.\n",
    "        if split is None:\n",
    "            bs = X.shape[0]\n",
    "            if bs <= batch_size:\n",
    "                return self(X)\n",
    "            return torch.cat([self(X[i:i+batch_size]) for i in range(0, bs, batch_size)])\n",
    "        batches = torch.as_tensor(split).split(batch_size)\n",
    "        return torch.cat([self(X[batch]) for batch in batches])\n",
    "\n",
    "    def forward(self, X):\n",
    "        # X: (batch_size, c_in, seq_len) -> (batch_size, num_features)\n",
    "\n",
    "        bs = X.shape[0]\n",
    "\n",
    "        diff_X = torch.diff(X) if self.divisor > 1 else None\n",
    "\n",
    "        Z = []\n",
    "\n",
    "        # one host transfer per tensor instead of one .item() call per dilation\n",
    "        for dilation_index, (d, p) in enumerate(zip(self.dilations.tolist(), self.paddings.tolist())):\n",
    "\n",
    "            W = getattr(self, f\"W_{dilation_index}\")\n",
    "            I = getattr(self, f\"I_{dilation_index}\")\n",
    "\n",
    "            # diff_index == 0 -> X\n",
    "            # diff_index == 1 -> diff(X)\n",
    "            for diff_index in range(self.divisor):\n",
    "                source = X if diff_index == 0 else diff_X\n",
    "                _Z = F.conv1d(source[:, I[diff_index]].sum(2), W[diff_index],\n",
    "                              dilation = d, padding = p, groups = self._g).view(bs, self._g, self.k, -1)\n",
    "\n",
    "                max_values, max_indices = _Z.max(2)\n",
    "                min_values, min_indices = _Z.min(2)\n",
    "\n",
    "                # accumulators take dtype / device from the activations: plain float32 zeros\n",
    "                # make scatter_add_ fail when the activations are not float32 (e.g. autocast)\n",
    "                count_max = _Z.new_zeros(bs, self._g, self.k)\n",
    "                count_min = _Z.new_zeros(bs, self._g, self.k)\n",
    "\n",
    "                count_max.scatter_add_(-1, max_indices, max_values)\n",
    "                count_min.scatter_add_(-1, min_indices, torch.ones_like(min_values))\n",
    "\n",
    "                Z.append(count_max)\n",
    "                Z.append(count_min)\n",
    "\n",
    "        Z = torch.cat(Z, 1).view(bs, -1)\n",
    "\n",
    "        if self.clip:\n",
    "            Z = F.relu(Z)\n",
    "\n",
    "        return Z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class HydraPlus(nn.Sequential):\n",
    "    \"\"\"Hydra model: a fixed random `HydraBackbonePlus` ('backbone') followed by a head ('head').\n",
    "\n",
    "    The head is `custom_head` when given, `rocket_nd_head` when `d` is set, and otherwise\n",
    "    Flatten -> [BatchNorm1d] -> [Dropout] -> Linear(num_features, c_out).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "        c_in:int, # num of channels in input\n",
    "        c_out:int, # num of channels in output\n",
    "        seq_len:int, # sequence length\n",
    "        d:tuple=None, # shape of the output (when ndim > 1)\n",
    "        k:int=8, # number of kernels per group\n",
    "        g:int=64, # number of groups\n",
    "        max_c_in:int=8, # max number of channels per group\n",
    "        clip:bool=True, # clip values >= 0\n",
    "        use_bn:bool=True, # use batch norm\n",
    "        fc_dropout:float=0., # dropout probability\n",
    "        custom_head:Any=None, # optional custom head as a torch.nn.Module or Callable\n",
    "        zero_init:bool=True, # set head weights and biases to zero\n",
    "        use_diff:bool=True, # use diff(X) as input\n",
    "        device:str=default_device(), # device to use\n",
    "        ):\n",
    "\n",
    "        # Backbone (use_diff is forwarded so that use_diff=False really disables the diff(X) groups)\n",
    "        backbone = HydraBackbonePlus(c_in, c_out, seq_len, k=k, g=g, max_c_in=max_c_in, clip=clip, device=device, zero_init=zero_init,\n",
    "                                     use_diff=use_diff)\n",
    "        num_features = backbone.num_features\n",
    "\n",
    "\n",
    "        # Head\n",
    "        if custom_head is not None:\n",
    "            # a ready-made module is used as is; any other callable is built as custom_head(nf, c_out, 1)\n",
    "            if isinstance(custom_head, nn.Module): head = custom_head\n",
    "            else: head = custom_head(num_features, c_out, 1)\n",
    "        elif d is not None:\n",
    "            head = rocket_nd_head(num_features, c_out, seq_len=None, d=d, use_bn=use_bn, fc_dropout=fc_dropout, zero_init=zero_init)\n",
    "        else:\n",
    "            layers = [Flatten()]\n",
    "            if use_bn:\n",
    "                layers += [nn.BatchNorm1d(num_features)]\n",
    "            if fc_dropout:\n",
    "                layers += [nn.Dropout(fc_dropout)]\n",
    "            linear = nn.Linear(num_features, c_out)\n",
    "            if zero_init:\n",
    "                nn.init.constant_(linear.weight.data, 0)\n",
    "                nn.init.constant_(linear.bias.data, 0)\n",
    "            layers += [linear]\n",
    "            head = nn.Sequential(*layers)\n",
    "\n",
    "        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))\n",
    "\n",
    "        # set only after nn.Module.__init__ has run, so the module state already exists\n",
    "        self.head_nf = num_features\n",
    "\n",
    "Hydra = HydraPlus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 3])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# default head (d=None): one vector of c_out values per sample\n",
    "bs, c_in, seq_len, c_out = 16, 5, 20, 3\n",
    "device = default_device()\n",
    "\n",
    "xb = torch.randn(bs, c_in, seq_len).to(device)\n",
    "yb = torch.randint(0, c_out, (bs, seq_len)).to(device)\n",
    "\n",
    "model = HydraPlus(c_in, c_out, seq_len, d=None).to(device)\n",
    "output = model(xb)\n",
    "assert output.shape == (bs, c_out)\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 3])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# passing use_diff=False must leave the output shape unchanged\n",
    "bs, c_in, seq_len, c_out = 16, 5, 20, 3\n",
    "device = default_device()\n",
    "\n",
    "xb = torch.randn(bs, c_in, seq_len).to(device)\n",
    "yb = torch.randint(0, c_out, (bs, seq_len)).to(device)\n",
    "\n",
    "model = HydraPlus(c_in, c_out, seq_len, d=None, use_diff=False).to(device)\n",
    "output = model(xb)\n",
    "assert output.shape == (bs, c_out)\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 20, 3])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# nd head (d=seq_len): one vector of c_out values per sample and per step\n",
    "bs, c_in, seq_len, c_out = 16, 5, 20, 3\n",
    "device = default_device()\n",
    "\n",
    "xb = torch.randn(bs, c_in, seq_len).to(device)\n",
    "yb = torch.randint(0, c_out, (bs, c_in, seq_len)).to(device)\n",
    "\n",
    "model = HydraPlus(c_in, c_out, seq_len, d=seq_len, use_diff=True).to(device)\n",
    "output = model(xb)\n",
    "assert output.shape == (bs, seq_len, c_out)\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/079_models.HydraPlus.ipynb saved at 2024-02-10 22:16:56\n",
      "Correct notebook to script conversion! 😃\n",
      "Saturday 10/02/24 22:16:59 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
